# -*- coding:utf8 -*-
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Loss-log files to compare; each is read with np.loadtxt.
NPU_LOSS_FILE = "0610_ms.txt"
GPU_LOSS_FILE = "0610_gpu.txt"
# Training-step spacing between consecutive plotted points (x axis is 0, 2, 4, ...).
STEP_INTERVAL = 2


def align_losses(loss_data, step_interval=STEP_INTERVAL):
    """Truncate every loss series to the shortest one and build a shared x axis.

    Args:
        loss_data: mapping of curve label -> 1-D sequence of loss values.
            Scalars / 0-d arrays are treated as length-1 series.
        step_interval: step distance between consecutive points.

    Returns:
        ``(axis, curves)`` where ``axis`` is ``np.arange(n) * step_interval``
        and ``curves`` lists each series cut to ``n`` points, in the mapping's
        iteration order (``n`` is the length of the shortest series).

    Raises:
        ValueError: if ``loss_data`` is empty.
    """
    # atleast_1d guards against 0-d arrays, on which len() would raise.
    series = [np.atleast_1d(np.asarray(values)) for values in loss_data.values()]
    if not series:
        raise ValueError("loss_data is empty; nothing to align")
    n_points = min(len(s) for s in series)
    axis = np.arange(n_points) * step_interval
    return axis, [s[:n_points] for s in series]


def plot_losses(loss_data, *, title, out_path="./loss.png",
                step_interval=STEP_INTERVAL, dpi=1000, show=True):
    """Plot aligned loss curves, save them to ``out_path`` and optionally show them.

    Args:
        loss_data: mapping of legend label -> loss values (see ``align_losses``).
        title: figure title.
        out_path: where the figure is written.
        step_interval: step distance between consecutive points.
        dpi: resolution passed to ``savefig``.
        show: whether to open the interactive window after saving.
    """
    axis, curves = align_losses(loss_data, step_interval)
    fig, ax = plt.subplots()
    # Attach each label to its own line so legend order cannot drift from plot order.
    for label, curve in zip(loss_data, curves):
        ax.plot(axis, curve, label=label)
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    ax.set_title(title)
    fig.savefig(out_path, dpi=dpi, bbox_inches='tight')
    if show:
        plt.show()


def main():
    """Load the NPU and GPU loss logs and plot them on one figure."""
    # Select the backend before any figure is created.
    matplotlib.use('TkAgg')
    loss_data = {
        # ndmin=1 keeps single-value files as 1-D arrays (slicing/len work).
        "telechat_npu": np.loadtxt(NPU_LOSS_FILE, ndmin=1),
        # Every other GPU entry is kept so both series share STEP_INTERVAL spacing.
        # NOTE(review): assumes the GPU log has one entry per step and the NPU
        # log one per STEP_INTERVAL steps — confirm against the loggers.
        "telechat_gpu": np.loadtxt(GPU_LOSS_FILE, ndmin=1)[::2],
    }
    plot_losses(loss_data, title="telechat_3b [wechat_cleaned]")


if __name__ == "__main__":
    main()
